import os
import time
import json
import re
import numpy as np
import importlib
from typing import List, Dict, Optional
from evaluate.data_loader import load_truth_file
from evaluate.gate_counter import count_gates


def load_method(method_name: str):
    """Import the ``method`` submodule named after *method_name*.

    Spaces in the name become underscores (``"my algo"`` -> ``method.my_algo``).
    The imported module is returned only if it defines ``find_expressions``;
    otherwise ``None`` is returned. Import errors are not caught here and
    propagate to the caller.
    """
    qualified_name = "method." + method_name.replace(" ", "_")
    candidate = importlib.import_module(qualified_name)
    if not hasattr(candidate, "find_expressions"):
        return None
    return candidate


def save_single_result(
    results: Dict,
    output_dir: str,
    method_name: str,
    data_file: str,
    file_index: int
):
    """Save a single experiment result to a JSON file.

    The output location mirrors *data_file*'s path relative to ``data``:
    ``<output_dir>/<relative path without extension>/<method_name>.json``.

    Args:
        results: Dict of parallel per-file lists (``expressions``,
            ``train_accuracy``, ``test_accuracy``, ``train_exact_match``,
            ``test_exact_match``, ``num_gates``, ``execution_time``).
        output_dir: Root directory for result files (created as needed).
        method_name: Base name of the JSON file.
        data_file: Path of the data file this result belongs to.
        file_index: Index of this file's entry in every ``results`` list.
    """
    rel_path = os.path.relpath(data_file, start='data')
    output_subdir = os.path.join(output_dir, os.path.splitext(rel_path)[0])
    os.makedirs(output_subdir, exist_ok=True)

    raw_gates = results['num_gates'][file_index]
    # NaN (or None) means "gate count unavailable"; emit JSON null for both.
    # Checking None first avoids np.isnan(None) raising TypeError.
    num_gates = None if raw_gates is None or np.isnan(raw_gates) else float(raw_gates)

    result = {
        'data_file': data_file,
        'expression': results['expressions'][file_index],
        'train_bit_acc': float(results['train_accuracy'][file_index]),
        'test_bit_acc': float(results['test_accuracy'][file_index]),
        'train_sample_acc': float(results['train_exact_match'][file_index]),
        'test_sample_acc': float(results['test_exact_match'][file_index]),
        'num_gates': num_gates,
        'execution_time': float(results['execution_time'][file_index]),
        'timestamp': time.strftime("%Y-%m-%d %H:%M:%S"),
        'progress': f"{file_index + 1}/{len(results['expressions'])}"
    }

    output_path = os.path.join(output_subdir, f"{method_name}.json")
    # ensure_ascii=False writes non-ASCII characters verbatim, so the file
    # encoding must be pinned instead of relying on the platform default.
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=2, ensure_ascii=False)

    print(f" Results saved to: {output_path}")


_NOISE_LEVEL_PATTERN = re.compile(r'noise_([0-9]*\.?[0-9]+)')


def _parse_noise_level(path: str) -> float:
    """Return the number following ``noise_`` in *path*, or 0.0 if absent.

    At most one decimal point is captured, so a trailing dot from a file
    extension (e.g. ``noise_0.1.txt``) does not break ``float()``.
    """
    match = _NOISE_LEVEL_PATTERN.search(path)
    return float(match.group(1)) if match else 0.0


def run_experiment(
    method_module,
    data_files: List[str],
    split: float,
    num_runs: int,
    output_dir: str,
    method_name: str,
    input_size: Optional[int] = None,
    output_size: Optional[int] = None,
    operators: Optional[List[str]] = None
) -> Dict:
    """Run *method_module* on every data file, saving each result immediately.

    Args:
        method_module: Module exposing ``find_expressions``. It must return
            ``(expressions, accuracies, _)``, or ``None`` on failure. Each
            ``accuracies`` item is indexed as ``acc[0]`` (train bit acc),
            ``acc[1]`` (test bit acc) and optionally ``acc[2]``/``acc[3]``
            (train/test exact match).
        data_files: Paths loaded with ``load_truth_file``.
        split: Forwarded to ``find_expressions``.
        num_runs: Not used by this function.
        output_dir: Root directory passed to ``save_single_result``.
        method_name: Name used for output files. ``'abc'`` enables
            ABC-specific gate counting.
        input_size, output_size: Not used by this function; kept for
            interface compatibility.
        operators: Allowed operators. Only forwarded to the
            ``nn_boolformer`` method.

    Returns:
        Dict of lists aligned with *data_files* by index: ``expressions``,
        ``train_accuracy``, ``test_accuracy``, ``num_gates``,
        ``execution_time``, ``train_exact_match``, ``test_exact_match``.
    """
    num_files = len(data_files)
    # One slot per data file, pre-filled with "no result" defaults so every
    # list stays index-aligned with data_files.
    results = {
        'expressions': [[] for _ in range(num_files)],
        'train_accuracy': [0.0] * num_files,
        'test_accuracy': [0.0] * num_files,
        'num_gates': [np.nan] * num_files,
        'execution_time': [0.0] * num_files,
        'train_exact_match': [0.0] * num_files,
        'test_exact_match': [0.0] * num_files,
    }

    def _handle_error_case(i, test_file, start_time, error_msg="Method error"):
        """Record a failed run (``["FAIL"]``, zero scores) for file *i* and save it."""
        print(f"  {error_msg} for {test_file}")
        results['expressions'][i] = ["FAIL"]
        results['train_accuracy'][i] = 0.0
        results['test_accuracy'][i] = 0.0
        results['num_gates'][i] = np.nan
        results['execution_time'][i] = time.time() - start_time
        results['train_exact_match'][i] = 0.0
        results['test_exact_match'][i] = 0.0
        save_single_result(results, output_dir, method_name, test_file, i)

    for i, test_file in enumerate(data_files):
        # Timing deliberately starts before loading, so it includes I/O.
        start_time = time.time()
        X, Y = load_truth_file(test_file)

        if getattr(method_module, '__name__', '').endswith('nn_boolformer'):
            # nn_boolformer also takes the noise level encoded in the file path.
            noise = _parse_noise_level(test_file)
            result = method_module.find_expressions(
                X, Y, split=split, noise=noise, allowed_operators=operators)
        else:
            result = method_module.find_expressions(X, Y, split=split)

        if result is None:
            _handle_error_case(i, test_file, start_time)
            continue

        expressions, accuracies, _ = result

        # An all-"0" answer is recorded as a failure. Skip the normal
        # bookkeeping below so the FAIL record is what is kept and saved.
        # NOTE(review): this function does not stop processing later files;
        # presumably the caller reacts to the "FAIL" marker — confirm.
        if expressions and all(str(expr) == "0" for expr in expressions):
            print("  Detected all '0' results, terminate subsequent tests at current scale and larger scales")
            print(f"   Debug: expressions = {expressions}")
            print(f"   File: {test_file}")
            _handle_error_case(i, test_file, start_time, "All '0' results detected")
            continue

        valid_expr = [e for e in expressions if e not in [None, "FAIL", 0]]

        # ABC reports its own gate counts via captured stdout.
        if method_name == 'abc' and hasattr(method_module, 'get_abc_stdouts'):
            abc_stdouts = method_module.get_abc_stdouts()
            num_gates = count_gates(valid_expr, X.shape[1], abc_stdouts)
        else:
            num_gates = count_gates(valid_expr, X.shape[1])

        results['expressions'][i] = expressions
        results['train_accuracy'][i] = np.mean([acc[0] for acc in accuracies])
        results['test_accuracy'][i] = np.mean([acc[1] for acc in accuracies])
        results['num_gates'][i] = num_gates
        results['execution_time'][i] = time.time() - start_time
        results['train_exact_match'][i] = np.mean([acc[2] if len(acc) > 2 else 0.0 for acc in accuracies])
        results['test_exact_match'][i] = np.mean([acc[3] if len(acc) > 3 else 0.0 for acc in accuracies])

        print(f"  File: {test_file}")
        expr_display = expressions[0] if len(expressions) == 1 else expressions
        print(f"  Expression: {expr_display}")
        print(f"  Train Bit Acc: {results['train_accuracy'][i]:.3f}, Test Bit Acc: {results['test_accuracy'][i]:.3f}")
        print(f"  Train Sample Acc: {results['train_exact_match'][i]:.3f}, Test Sample Acc: {results['test_exact_match'][i]:.3f}\n")

        save_single_result(results, output_dir, method_name, test_file, i)

    return results
